#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging as _logging
from collections import OrderedDict as _OrderedDict
from typing import Any as _Any
from typing import Callable as _Callable
from typing import Dict as _Dict
from typing import NewType as _NewType
from typing import Optional as _Optional
from typing import Tuple as _Tuple
from typing import Union as _Union

import cattrs as _cattrs
import torch as _torch
from attr import define as _define
from attr import field as _field
from attrs import validators as _validators

from coremltools.optimize.torch._utils.k_means import KMeansConfig as _KMeansConfig
from coremltools.optimize.torch._utils.k_means import (
    KMeansSupportedModulesRegistry as _KMeansSupportedModulesRegistry,
)
from coremltools.optimize.torch._utils.k_means import ParallelKMeans as _ParallelKMeans
from coremltools.optimize.torch._utils.k_means import SequentialKMeans as _SequentialKMeans
from coremltools.optimize.torch._utils.optimizer_utils import (
    _ConfigToOptimizerRegistry,
    _ModuleToOptConfigRegistry,
)
from coremltools.optimize.torch._utils.report_utils import (
    compute_post_training_report as _compute_post_training_report,
)
from coremltools.optimize.torch._utils.torch_utils import (
    maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype,
)
from coremltools.optimize.torch._utils.torch_utils import (
    maybe_convert_str_to_mod_type as _maybe_convert_str_to_mod_type,
)
from coremltools.optimize.torch._utils.validation_utils import (
    validate_param_config as _validate_param_config,
)
from coremltools.optimize.torch.base_model_optimizer import (
    BasePostTrainingModelOptimizer as _BasePostTrainingModelOptimizer,
)
from coremltools.optimize.torch.base_model_optimizer import _Report
from coremltools.optimize.torch.optimization_config import (
    ModuleOptimizationConfig as _ModuleOptimizationConfig,
)
from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig
from coremltools.optimize.torch.optimization_config import (
    PalettizationGranularity,
    _structure_from_dict_hook_factory,
)

_logger = _logging.getLogger(__name__)


@_define
class ModulePostTrainingPalettizerConfig(_ModuleOptimizationConfig):
    """
    Configuration class for specifying global and module-level palettization options for
    :py:class:`PostTrainingPalettizerConfig` algorithm.

    Args:
        n_bits (:obj:`int`): Number of bits to use for palettizing the weights. Defaults to ``4``.
        lut_dtype (:py:class:`torch.dtype`): The dtype to use for representing each element in lookup tables.
            When value is ``None``, no quantization is performed. Supported values are :py:class:`torch.int8` and
            :py:class:`torch.uint8`. Defaults to ``None``.
        granularity (:py:class:`PalettizationGranularity`): Granularity for palettization.
            One of ``per_tensor`` or ``per_grouped_channel``. Defaults to ``per_tensor``.
        group_size (:obj:`int`): Specify the number of channels in a group.
            Only effective when granularity is ``per_grouped_channel``.
        channel_axis (:obj:`int`): Specify the channel axis used to form groups of channels.
            Only effective when granularity is ``per_grouped_channel``. Defaults to ``0``, the output channel axis.
        cluster_dim (:obj:`int`): The dimension of the centroids in each lookup table.
            The centroid is a scalar by default. When ``cluster_dim > 1``, weights are clustered as vectors:
            each group of ``cluster_dim`` consecutive weight elements along the output channel axis is palettized
            using the same centroid vector, and each entry in the lookup table has length ``cluster_dim``.
        enable_per_channel_scale (:obj:`bool`): When set to ``True``, weights are normalized along the output channels
            using per-channel scales before being palettized. This is not supported with ``cluster_dim > 1``.
        enable_fast_kmeans_mode (:obj:`bool`): When set to ``True``, weights are rounded before clustering
            if the data is within the fp16 range. If the weight dtype is fp32, weights are first cast to fp16
            and then rounded. This is not supported with ``cluster_dim > 1``. Defaults to ``True``.
        rounding_precision (:obj:`int`): The number of decimal places used for rounding when
            ``enable_fast_kmeans_mode`` is enabled. Choose a lower precision for faster processing,
            at the cost of a coarser approximation. Defaults to ``4``.

    This class supports two different configurations to structure the palettization:

    1. **Per-tensor palettization**:  This is the default configuration where the whole tensor shares a single lookup
    table. The ``granularity`` is set to ``per_tensor``, and ``group_size`` is ``None``.

    2. **Per-grouped-channel palettization**: In this configuration, each group of ``group_size`` channels along
    ``channel_axis`` shares the same lookup table. For example, for a weight matrix of shape ``(16, 25)`` with
    ``group_size = 8`` along the output channel axis, the ``16`` output channels form two groups, so the lookup
    table has shape ``(2, 2^n_bits)``.

    .. note::
        Grouping is currently only supported along either the input or output channel axis.
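
    The following is a minimal sketch of a per-grouped-channel config (the values shown are illustrative):

    .. code-block:: python

        config = ModulePostTrainingPalettizerConfig(
            n_bits=4,
            granularity="per_grouped_channel",
            group_size=8,
        )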
    """

    n_bits: _Optional[int] = _field(
        default=4, validator=_validators.optional(_validators.instance_of(int))
    )
    lut_dtype: _Optional[_torch.dtype] = _field(
        default=None,
        converter=lambda val: _maybe_convert_str_to_dtype(val) if val else val,
        validator=_validators.optional(
            [
                _validators.instance_of(_torch.dtype),
                _validators.in_([_torch.int8, _torch.uint8]),
            ]
        ),
    )
    granularity: PalettizationGranularity = _field(
        default="per_tensor",
        converter=PalettizationGranularity,
        validator=_validators.in_(PalettizationGranularity),
    )
    group_size: _Optional[int] = _field(
        default=None, validator=_validators.optional(_validators.instance_of(int))
    )
    channel_axis: int = _field(
        default=0,
        validator=_validators.optional([_validators.instance_of(int), _validators.in_([0, 1])]),
    )
    cluster_dim: _Optional[int] = _field(
        default=None, validator=_validators.optional(_validators.instance_of(int))
    )
    enable_per_channel_scale: _Optional[bool] = _field(
        default=False, validator=_validators.optional(_validators.instance_of(bool))
    )
    enable_fast_kmeans_mode: _Optional[bool] = _field(
        default=True, validator=_validators.optional(_validators.instance_of(bool))
    )
    rounding_precision: _Optional[int] = _field(
        default=4, validator=_validators.optional(_validators.instance_of(int))
    )

    @group_size.validator
    def per_grouped_channel_granularity(self, attribute, value):
        if self.granularity == PalettizationGranularity.per_grouped_channel:
            assert (
                value is not None
            ), "group_size has to be specified along with per_grouped_channel granularity."
            assert value > 0, "group_size should be greater than zero"
        else:
            assert value is None, "group_size can't be specified along with per_tensor granularity."

    @cluster_dim.validator
    def no_per_channel_scale(self, attribute, value):
        if value and value > 1:
            assert (
                not self.enable_per_channel_scale
            ), f"Enabling per_channel_scale is not supported for cluster_dim > 1 (got cluster_dim={value})."


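# ``_ModuleTypeConfigType`` below lets ``from_dict`` register a dedicated cattrs
# structure hook for dictionaries that map module types to their configs.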
_ModuleTypeConfigType = _NewType(
    "ModuleTypeConfigType",
    _Dict[_Union[_Callable, str], _Optional[ModulePostTrainingPalettizerConfig]],
)


@_ModuleToOptConfigRegistry.register_module_cfg(ModulePostTrainingPalettizerConfig)
@_define
class PostTrainingPalettizerConfig(_OptimizationConfig):
    """
    Configuration class for specifying how different submodules of a model
    should be post-training palettized by :py:class:`PostTrainingPalettizer`.

    Args:
        global_config (:py:class:`ModulePostTrainingPalettizerConfig`): Config to be applied globally
            to all supported modules.
        module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`):
            Module type configs applied to a specific module class, such as :py:class:`torch.nn.Linear`.
            The keys can be either strings or module classes.
        module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`):
            Module-level configs applied to specific modules. The keys are fully qualified module names, and the
            values are the corresponding :py:class:`ModulePostTrainingPalettizerConfig` instances.
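
    The following is a minimal sketch combining the three config levels (the module names are hypothetical):

    .. code-block:: python

        config = PostTrainingPalettizerConfig.from_dict(
            {
                "global_config": {"n_bits": 4},
                "module_type_configs": {
                    "Linear": {"n_bits": 2},
                },
                "module_name_configs": {
                    # Setting a module's config to None leaves it uncompressed.
                    "conv2": None,
                },
            }
        )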
    """

    global_config: _Optional[ModulePostTrainingPalettizerConfig] = _field(
        default=None,
        validator=_validators.optional(_validators.instance_of(ModulePostTrainingPalettizerConfig)),
    )
    module_type_configs: _ModuleTypeConfigType = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of((str, _Callable)),
            value_validator=_validators.optional(
                _validators.instance_of(ModulePostTrainingPalettizerConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )
    module_name_configs: _Dict[str, _Optional[ModulePostTrainingPalettizerConfig]] = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of(str),
            value_validator=_validators.optional(
                _validators.instance_of(ModulePostTrainingPalettizerConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )

    def __attrs_post_init__(self):
        if (
            self.global_config is None
            and len(self.module_type_configs) == 0
            and len(self.module_name_configs) == 0
        ):
            self.global_config = ModulePostTrainingPalettizerConfig()
        self.module_type_configs = {
            _maybe_convert_str_to_mod_type(key): val
            for key, val in self.module_type_configs.items()
        }

    @classmethod
    def from_dict(cls, config_dict: _Dict[str, _Any]) -> "PostTrainingPalettizerConfig":
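        # Let the base class validate config_dict before structuring it with cattrs below.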
        super().from_dict(config_dict)
        converter = _cattrs.Converter(forbid_extra_keys=True)
        converter.register_structure_hook(
            _ModuleTypeConfigType,
            _structure_from_dict_hook_factory(ModulePostTrainingPalettizerConfig),
        )
        return converter.structure_attrs_fromdict(config_dict, cls)


@_ConfigToOptimizerRegistry.register_config(PostTrainingPalettizerConfig)
class PostTrainingPalettizer(_BasePostTrainingModelOptimizer):
    """
    Perform post-training palettization on a torch model. After palettization, the weights of all supported
    layers point to entries in a lookup table computed by a k-means operation.

    Example:

            .. code-block:: python

                from collections import OrderedDict

                import torch.nn as nn
                from coremltools.optimize.torch.palettization import (
                    PostTrainingPalettizerConfig,
                    PostTrainingPalettizer,
                )

                model = nn.Sequential(
                    OrderedDict(
                        {
                            "conv": nn.Conv2d(1, 20, (3, 3)),
                            "relu1": nn.ReLU(),
                            "conv2": nn.Conv2d(20, 20, (3, 3)),
                            "relu2": nn.ReLU(),
                        }
                    )
                )

                # create the config and initialize the palettizer
                config = PostTrainingPalettizerConfig.from_dict(
                    {
                        "global_config": {
                            "n_bits": 4,
                        },
                    }
                )

                ptpalettizer = PostTrainingPalettizer(model, config)
                palettized_model = ptpalettizer.compress()

    Args:
        model (:obj:`torch.nn.Module`): Module to be compressed.
        config (:py:class:`PostTrainingPalettizerConfig`): Config that specifies how
            different submodules in the model will be palettized.
    """

    _supported_modules: _Tuple = _KMeansSupportedModulesRegistry.get_supported_modules()

    def __init__(self, model: _torch.nn.Module, config: _Optional[PostTrainingPalettizerConfig] = None):
        config = PostTrainingPalettizerConfig() if config is None else config
        super().__init__(model, config)

    def compress(self, num_kmeans_workers: int = 1, inplace: bool = False) -> _torch.nn.Module:
        """
        The compress method performs a k-means operation on all supported modules.

        Args:
            num_kmeans_workers (:obj:`int`): Number of worker processes used for
                performing post-training palettization. It is recommended to use more than one worker
                process to parallelize the clustering, especially when multiple CPUs are available.
                Defaults to ``1``.
            inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and
                the original module is mutated, otherwise a copy of the model is mutated and returned.
                Defaults to ``False``.
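
        Example (illustrative usage; ``palettizer`` is assumed to be an instance of
        :py:class:`PostTrainingPalettizer`):

        .. code-block:: python

            palettized_model = palettizer.compress(num_kmeans_workers=4)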
        """
        self._model = super().compress(inplace=inplace)
        kmeans_config_dict = dict()
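        # Collect a per-parameter k-means config for every supported module that has a config attached.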
        for name, submodule in self._model.named_modules():
            submod_config = self._config.get_module_config(name, submodule)
            if submod_config is None:
                continue

            k_means_module_cls = _KMeansSupportedModulesRegistry.get_kmeans_module(submodule)
            if k_means_module_cls is None:
                continue

            for param_name in k_means_module_cls.parameter_names:
                # Validate configuration for parameter
                param = submodule.get_parameter(param_name)
                updated_config = _validate_param_config(
                    name + "." + param_name,
                    param,
                    submodule,
                    submod_config,
                    ["palettization_group_size", "palettization_cluster_dim"],
                )
                if not updated_config:
                    continue

                if name not in kmeans_config_dict:
                    kmeans_config_dict[name] = {}

                kmeans_config_dict[name][param_name] = _KMeansConfig(
                    n_bits=updated_config.n_bits,
                    axis=updated_config.channel_axis,
                    lut_dtype=updated_config.lut_dtype,
                    block_size=updated_config.group_size,
                    cluster_dim=updated_config.cluster_dim,
                    enable_per_channel_scale=updated_config.enable_per_channel_scale,
                    enable_fast_kmeans_mode=updated_config.enable_fast_kmeans_mode,
                    rounding_precision=updated_config.rounding_precision,
                )

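        # Run clustering in parallel worker processes when requested; otherwise cluster sequentially in-process.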
        if num_kmeans_workers > 1:
            return _ParallelKMeans.cluster_weights(
                self._model, kmeans_config_dict, num_workers=num_kmeans_workers
            )
        else:
            return _SequentialKMeans.cluster_weights(self._model, kmeans_config_dict)

    def report(self) -> _Report:
        """
        Returns a report summarizing the compression applied to each supported module, computed by
        comparing the uncompressed model with the compressed model.
        """
        return _compute_post_training_report(
            self._uncompressed_model,
            self._model,
            supported_modules=self._supported_modules,
        )
